package com.micro.platform.gateway.filter.encryption;


import com.alibaba.cloud.commons.io.IOUtils;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.JSONValidator;
import com.micro.platform.gateway.constant.GatewayConstants;
import com.micro.platform.gateway.properties.LocalGatewayProperties;
import com.micro.platform.gateway.service.TokenService;
import com.micro.platform.oauth.dto.TokenInfoDTO;
import com.micro.platform.starter.utils.NumberUtils;
import com.micro.platform.starter.utils.RSACoderUtil;
import com.micro.platform.starter.utils.StringUtils;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.cloud.gateway.filter.NettyWriteResponseFilter;
import org.springframework.core.Ordered;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.netty.http.client.HttpClientResponse;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.StringWriter;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.zip.GZIPInputStream;

import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.CLIENT_RESPONSE_ATTR;

@Component
public class GlobalResponseBodyEncodeFilter implements GlobalFilter, Ordered {

    private Logger log = LoggerFactory.getLogger(GlobalResponseBodyEncodeFilter.class);

    @Value("${encryption.switch:false}")
    private boolean encryptionSwitch;

    @Autowired
    private TokenService tokenService;

    @Autowired
    private LocalGatewayProperties localGatewayProperties;

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        if(encryptionSwitch){
            ServerHttpRequest request = exchange.getRequest();
            if((request.getMethod() == HttpMethod.POST
                    || request.getMethod() == HttpMethod.PUT
                    || request.getMethod() == HttpMethod.PATCH)){
                String uri = request.getURI().getRawPath();
                ServerHttpResponse originalResponse = exchange.getResponse();
                DataBufferFactory bufferFactory = originalResponse.bufferFactory();
                ServerHttpResponseDecorator decoratedResponse = new ServerHttpResponseDecorator(originalResponse) {
                    @Override
                    public HttpHeaders getHeaders() {
                        String publicKey = exchange.getAttributeOrDefault(GatewayConstants.GATEWAY_FILTER_PUBLIC_KEY, null);
                        if (StringUtils.isEmpty(publicKey)) {
                            return super.getHeaders();
                        }
                        Object o = exchange.getAttributes().get(CLIENT_RESPONSE_ATTR);
                        HttpHeaders headers = new HttpHeaders();
                        if(o != null && o instanceof HttpClientResponse){
                            HttpClientResponse httpClientResponse = (HttpClientResponse)o;
                            httpClientResponse.responseHeaders().forEach(
                                    entry -> headers.add(entry.getKey(), entry.getValue()));
                        }
                        headers.set(HttpHeaders.TRANSFER_ENCODING, "chunked");
                        headers.putAll(super.getHeaders());
                        return headers;
                    }
                    @Override
                    public Mono<Void> writeWith(Publisher<? extends DataBuffer> body) {
                        if (body instanceof Flux) {
                            Flux<? extends DataBuffer> fluxBody = (Flux<? extends DataBuffer>) body;
                            return super.writeWith(fluxBody.buffer().map(dataBuffers -> {
                                DataBufferFactory dataBufferFactory = new DefaultDataBufferFactory();
                                DataBuffer join = dataBufferFactory.join(dataBuffers);
                                byte[] content = new byte[join.readableByteCount()];
                                join.read(content);
                                //释放掉内存
                                DataBufferUtils.release(join);
                                if (isloginPath(uri)) {
                                    String json = new String(content, Charset.forName("UTF-8"));
                                    analysisPublicKey(exchange, json);
                                }
                                String publicKey = exchange.getAttributeOrDefault(GatewayConstants.GATEWAY_FILTER_PUBLIC_KEY, null);
                                if(StringUtils.isEmpty(publicKey)){
                                    return bufferFactory.wrap(content);
                                }
                                // 私钥加密返回的响应体
                                byte[] encodedBytes;
                                List<String> encodings = this.getHeaders().get(HttpHeaders.CONTENT_ENCODING);
                                boolean gzipFlag = encodings != null && encodings.contains("gzip");
                                if (gzipFlag) {
                                    GZIPInputStream gzipInputStream = null;
                                    try {
                                        gzipInputStream = new GZIPInputStream(new ByteArrayInputStream(content), content.length);
                                        StringWriter writer = new StringWriter();
                                        IOUtils.copy(gzipInputStream, writer, StandardCharsets.UTF_8);
                                        content = writer.toString().getBytes(StandardCharsets.UTF_8);
                                    } catch (IOException e) {
                                        log.error("====Gzip IO error", e);
                                        throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, "Gzip 解压错误");
                                    } finally {
                                        if (gzipInputStream != null) {
                                            try {
                                                gzipInputStream.close();
                                            } catch (IOException e) {
                                                log.error("===Gzip IO close error", e);
                                                throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, "Gzip 流关闭错误");
                                            }
                                        }
                                    }
                                }
                                log.info("GlobalResponseBodyEncodeFilter filter ,gzip resp is : {}", new String(content));
                                try {
                                    encodedBytes = RSACoderUtil.encryptByPublicKey(content, publicKey);
                                } catch (Exception e) {
                                    e.printStackTrace();
                                    log.error("加密异常：" + e.getMessage());
                                    throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, "响应体加密异常");
                                }
                                String str = RSACoderUtil.encryptBASE64(encodedBytes);
                                byte[] bytes = str.getBytes(StandardCharsets.UTF_8);
                                return bufferFactory.wrap(bytes);
                            }));
                        }
                        return super.writeWith(body);
                    }
                };
                return chain.filter(exchange.mutate().response(decoratedResponse).build());
            }
            return chain.filter(exchange);
        }
        return chain.filter(exchange);
    }

    private void analysisPublicKey(ServerWebExchange exchange, String json) {
        String accessToken = analysisAccessToken(json);
        if(StringUtils.isEmpty(accessToken)){
            return;
        }
        TokenInfoDTO tokenInfoDTO = tokenService.analysisAccessToken(accessToken);
        // 密钥放入属性中向下传递
        exchange.getAttributes().put(GatewayConstants.GATEWAY_FILTER_PUBLIC_KEY, tokenInfoDTO.getPublicKey());
    }

    private String analysisAccessToken(String json) {
        if(JSONValidator.from(json).validate()){
            JSONObject jsonObject = JSONObject.parseObject(json);
            Integer code = jsonObject.getInteger("code");
            if(NumberUtils.equalsInteger(HttpStatus.OK.value(), code)){
                JSONObject data = jsonObject.getJSONObject("data");
                if(data != null){
                    return data.getString("accessToken");
                }
            }
        }
        return null;
    }

    private boolean isloginPath(String uri) {
        return localGatewayProperties.getApi().getLoginPath().contains(uri);
    }
    @Override
    public int getOrder() {
        return NettyWriteResponseFilter.WRITE_RESPONSE_FILTER_ORDER - 10;
    }
}